
Fix llama model sdpa attention forward function masking bug when output_attentions=True #30652

Merged: 14 commits into huggingface:main, May 15, 2024

Conversation

@Aladoro (Contributor) commented May 4, 2024

What does this PR do?

This is a very simple fix for a nasty issue I recently encountered. Due to its simplicity, I opened a PR directly without raising an issue first, to avoid redundancy. Please let me know if I should also raise an issue, and I'll do that right away.

Description

When output_attentions is True, the sdpa attention implementation's forward method falls back to the eager implementation's forward method. However, the mask preparation still returns a None mask whenever sdpa's 'AttentionMaskConverter._ignore_causal_mask_sdpa' returns True (which happens whenever the input is unmasked, since sdpa defers causal masking to PyTorch's scaled_dot_product_attention kernel).
This inconsistency causes the model to run the eager implementation with no causal attention mask whenever the original input is unmasked (e.g., when a single input sequence is encoded, or when all encoded input sequences have the same length) and output_attentions=True.
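
To make the failure mode concrete, here is a small self-contained illustration (toy code, not the actual transformers internals): the sdpa kernel can apply causal masking implicitly via is_causal=True, so a None mask is harmless for it, but an eager implementation that receives a None mask simply applies no masking at all.

import torch
import torch.nn.functional as F

def eager_attention(q, k, v, mask=None):
    # Plain attention: with mask=None, every position attends to every other position.
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores + mask  # additive mask: -inf on disallowed positions
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q = k = v = torch.randn(1, 4, 8)  # (batch, seq_len, head_dim)

# sdpa path: the prepared mask is None and causality is deferred to is_causal=True.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Eager fallback given the same None mask: no causal masking is applied.
eager_unmasked = eager_attention(q, k, v, mask=None)

# Eager fallback given an explicit causal mask matches the sdpa result.
causal_mask = torch.full((4, 4), float("-inf")).triu(diagonal=1)
eager_masked = eager_attention(q, k, v, mask=causal_mask)

print(torch.allclose(sdpa_out, eager_masked, atol=1e-5))    # expected: True
print(torch.allclose(sdpa_out, eager_unmasked, atol=1e-5))  # expected: False

This is why the generations in the reproduction script in the next comment only degrade when output_attentions=True.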

Tagging @ArthurZucker and @younesbelkada

@Aladoro (Contributor, Author) commented May 4, 2024

This erroneous behavior can be reproduced with the following minimal example:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer(
    ["Today is the day I went to the store and ..."], return_tensors="pt"
).to("cuda")

expanded_batch_size = 1

# Greedy generation without output_attentions: correct, causally-masked outputs.
outputs = model.generate(
    input_ids=inputs["input_ids"].expand(expanded_batch_size, -1),
    attention_mask=inputs["attention_mask"].expand(expanded_batch_size, -1),
    do_sample=False,
    max_new_tokens=5,
    return_dict_in_generate=True,
)

for sequence in outputs.sequences:
    print(tokenizer.decode(sequence))

# separator
print("-" * 20)

# Same greedy generation with output_attentions=True: sdpa falls back to the
# eager attention implementation but receives a None causal mask.
outputs = model.generate(
    input_ids=inputs["input_ids"].expand(expanded_batch_size, -1),
    attention_mask=inputs["attention_mask"].expand(expanded_batch_size, -1),
    do_sample=False,
    max_new_tokens=5,
    return_dict_in_generate=True,
    output_attentions=True,  # ?!
)

# garbage generated outputs since no masking is applied
for sequence in outputs.sequences:
    print(tokenizer.decode(sequence))
@ArthurZucker (Collaborator) left a comment


Great catch.

  1. The call causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) on line 1127 needs to be skipped as well.
  2. We need to add your small example script as a test! 🤗

@Aladoro (Contributor, Author) commented May 6, 2024

> Great catch.
>
>   1. The call causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) on line 1127 needs to be skipped as well.
>   2. We need to add your small example script as a test! 🤗

@ArthurZucker Thanks for reviewing my pull request and all your work in maintaining this awesome repo! :) Regarding your comments:

  1. Done (a rough sketch of the guarded logic is included after this list).
  2. Let me know if you would like me to write a small test script for this bug myself (i.e., checking that the outputs generated with the 'eager' implementation match the outputs generated with output_attentions=True), although inherent stochasticity in the GPU kernels might make it difficult to always get 100% consistent results.
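
For reference, here is a rough sketch of the shape of the guard (illustrative only, not the exact diff in this PR; the function names below are made up):

import torch

def make_causal_mask(seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Additive causal mask: dtype-min above the diagonal, zeros elsewhere.
    return torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype).triu(diagonal=1)

def update_causal_mask_sketch(attention_mask, seq_len, output_attentions, use_sdpa):
    # Hypothetical stand-in for the model's causal-mask preparation.
    # Pre-fix: with sdpa and a fully unmasked input, return None and defer
    # causal masking to torch.nn.functional.scaled_dot_product_attention.
    # Post-fix: only take that shortcut when the sdpa kernel will actually run,
    # i.e. when output_attentions is False (the eager fallback needs the mask).
    input_is_unmasked = attention_mask is None or bool(attention_mask.all())
    if use_sdpa and not output_attentions and input_is_unmasked:
        return None
    return make_causal_mask(seq_len)

print(update_causal_mask_sketch(None, 4, output_attentions=False, use_sdpa=True))  # None
print(update_causal_mask_sketch(None, 4, output_attentions=True, use_sdpa=True))   # explicit causal mask

The call to AttentionMaskConverter._unmask_unattended is now skipped in the same way when output_attentions is True.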

p.s. Some CircleCI tests seem to be failing on the main branch... and they are now failing here as well after I merged main into this branch.

@ArthurZucker (Collaborator) commented May 6, 2024

For 2., the test is already implemented, but I don't think it covers output_attentions=True. It's probably a matter of adding a parametrization. See this file (and the generate tests): https://github.com/huggingface/transformers/blob/main/tests/test_modeling_common.py#L3590.

Potentially adding output_attentions to make sure that sdpa with output_attentions=True matches eager with or without it (which it is supposed to!).
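
Something along these lines (a rough sketch, not the actual test in test_modeling_common.py; it assumes pytest and that the tiny hf-internal-testing/tiny-random-LlamaForCausalLM checkpoint is available):

import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

CHECKPOINT = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed tiny test model

@pytest.mark.parametrize("output_attentions", [False, True])
def test_sdpa_matches_eager_generate(output_attentions):
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    inputs = tokenizer(["Today is the day I went to the store and"], return_tensors="pt")

    sequences = {}
    for attn_implementation in ("eager", "sdpa"):
        model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT, attn_implementation=attn_implementation
        )
        out = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=5,
            output_attentions=output_attentions,
            return_dict_in_generate=True,
        )
        sequences[attn_implementation] = out.sequences

    # Greedy decoding is deterministic, so the generated token ids should match.
    torch.testing.assert_close(sequences["eager"], sequences["sdpa"])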

@ArthurZucker (Collaborator) commented

Feel free to rebase; it might be fixed on main, or it might just be flaky.

@Aladoro (Contributor, Author) commented May 6, 2024

> Feel free to rebase; it might be fixed on main, or it might just be flaky.

Just did :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Aladoro (Contributor, Author) commented May 7, 2024

@ArthurZucker Let me know if you think this fix is ready for merging, or if you'd like to add the tests to the same PR!

@ArthurZucker mentioned this pull request May 9, 2024
@ArthurZucker (Collaborator) left a comment


Would be nice to just add the test in this PR 😉

@Aladoro (Contributor, Author) commented May 12, 2024

> Would be nice to just add the test in this PR 😉

Alright - I added output_attentions=True to the sdpa equivalence test, as you suggested ;) (Black re-formatting seems to have messed up the diff, but the actual changes are minimal...)

@ArthurZucker - Let me know if there are any outstanding issues or if there is something else missing before merging ^^

@ArthurZucker (Collaborator) left a comment


Ok, let's make sure you rebase, as Gemma was updated a bit, and commit with [run-slow] so that the slow tests are run!

tests/test_modeling_common.py: review comment thread (outdated, resolved)
Aladoro and others added 2 commits May 15, 2024 15:10
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@ArthurZucker (Collaborator) left a comment


LGTM, fyi @fxmarty @gante and @ydshieh

@ArthurZucker (Collaborator) commented May 15, 2024

(Merging once the CIs are all green!)

@Aladoro (Contributor, Author) commented May 15, 2024

@ArthurZucker thanks for your suggestions! I also propagated the same changes to the new jetmoe model. All default checks are now passing ^^

@ArthurZucker merged commit 4b3eb19 into huggingface:main May 15, 2024 (22 checks passed)
@ArthurZucker (Collaborator) commented

Thanks for the fix!

@Aladoro deleted the fix-llama-mask-output-attn branch May 15, 2024 23:32
itazap pushed a commit that referenced this pull request May 24, 2024
…ut_attentions=True (#30652)

* Fix llama model forward function with attention=True, same-length encoded sequence.

* Fix style

* propagate fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same sdpa masking logic from llama)

* Fix style

* ignore unnecessary sdpa mask converter when output_attentions=True

* add tests checking sdpa and eager outputs match when output_attentions=True

* Split if statements in two lines

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix formatting

* Add fix to new jetmoe model

* Add missing output_attentions argument to jetmoe mask creation

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@gante (Member) commented May 28, 2024

@Aladoro thank you for detecting the issue and making transformers better for all of us 💛

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 11, 2024
…ut_attentions=True (huggingface#30652)
